
adds count num sequences and tokens metric #346

Merged
2 commits merged into master on Mar 21, 2024

Conversation

mosheraboh (Collaborator):

No description provided.

SagiPolaczek (Collaborator) previously approved these changes on Mar 20, 2024 and left a comment:

LGTM, thanks!

I added questions, typing mistakes, and leftover artifacts inline. None of them change the logic.

from typing import List, Optional, Tuple

import torch
import numpy as np

from fuse.eval.metrics.metrics_common import MetricPerBatchDefault


class MetricCountSeqAndTokens(MetricPerBatchDefault):

SagiPolaczek (Collaborator):

General question:
I'm not sure counting the sequences and tokens should be defined as a metric. I don't have another suggestion, it just sounds weird :)

What do you think?

mosheraboh (Collaborator Author):

It uses the metric mechanism, and it's OK with me that it just counts some stats.

) -> None:
"""
:param encoder_input: key to the encoder_input
:param ignore_index: token_id to ignore (not to count), typically pad token id

SagiPolaczek (Collaborator):

I think it should be able to support a list of token ids to ignore, unless you want to force the user to ignore only the PAD one.

mosheraboh (Collaborator Author):

I went with just one to be more efficient - typically we would like to just skip the padding.
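For context on the trade-off discussed above, a minimal sketch (illustrative only, not the PR's implementation) of why a single ignore_index keeps the count a single vectorized comparison, while a list of ids needs one comparison per id:

from typing import List

import torch

# Hypothetical illustration: a single ignore id needs one elementwise comparison.
def count_tokens_single(tokens: torch.Tensor, ignore_index: int) -> torch.Tensor:
    return (tokens != ignore_index).sum()

# A list of ignore ids needs one comparison (and one temporary mask) per id.
def count_tokens_multi(tokens: torch.Tensor, ignore_indices: List[int]) -> torch.Tensor:
    mask = torch.ones_like(tokens, dtype=torch.bool)
    for idx in ignore_indices:
        mask &= tokens != idx
    return mask.sum()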

:param kwargs: additional super class arguments
"""
super().__init__(
seq_num="seq_num", # collect log_probs - output of _count_seq_and_tokens_update

SagiPolaczek (Collaborator):

obsolete comments in this line and the following one

mosheraboh (Collaborator Author):

👍


def _count_seq_and_tokens_update(
batch_dict: dict,
encoder_input_key: str,

SagiPolaczek (Collaborator):

encoder_input_key: Union[str, None] 

or

encoder_input_key: Optional[str]

mosheraboh (Collaborator Author):

It's a must. Why optional?


def _count_seq_and_tokens_compute(
self,
seq_num: List[np.ndarray],

SagiPolaczek (Collaborator):

seq_num will be a numpy array such that each entry represents a batch? If so, how often are the metrics calculated? Each epoch?

I forgot these :)

mosheraboh (Collaborator Author):

Each sub-epoch, and each entry is a batch.

self,
seq_num: List[np.ndarray],
token_num: List[np.ndarray],
) -> float:

SagiPolaczek (Collaborator):

returns a dict

mosheraboh (Collaborator Author):

👍
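To illustrate the two points above (each list entry is one batch's count, and the result should be a dict rather than a float), a hedged sketch of what the compute step could look like, written as a standalone function with assumed names:

from typing import Dict, List

import numpy as np

# Hypothetical sketch: sum the per-batch counts collected over a sub-epoch.
def count_seq_and_tokens_compute(
    seq_num: List[np.ndarray],
    token_num: List[np.ndarray],
) -> Dict[str, int]:
    return {
        "seq_num": int(np.sum(seq_num)),
        "token_num": int(np.sum(token_num)),
    }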

batch_dict: dict,
encoder_input_key: str,
ignore_index: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:

SagiPolaczek (Collaborator):

-> dict[str, Tensor]

mosheraboh (Collaborator Author):

👍
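Along the same lines, a hedged sketch of a per-batch update step with the suggested dict return type (parameter names taken from the signature above; the body is assumed, not the merged implementation):

from typing import Dict, Optional

import torch

# Hypothetical sketch: count sequences and non-ignored tokens in one batch
# and return both counts in a dict of tensors.
def count_seq_and_tokens_update(
    batch_dict: dict,
    encoder_input_key: str,
    ignore_index: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    tokens = batch_dict[encoder_input_key]  # token ids, shape [batch_size, seq_len]
    seq_num = torch.tensor(tokens.shape[0])
    if ignore_index is None:
        token_num = torch.tensor(tokens.numel())
    else:
        token_num = (tokens != ignore_index).sum()
    return {"seq_num": seq_num, "token_num": token_num}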

SagiPolaczek (Collaborator) commented on Mar 20, 2024:

Last comment: did you try to write a test for it, so we'll have it covered?

If time doesn't permit, maybe open it as a card on Monday and we'll get to it later.
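For reference, a minimal test along those lines could look like the sketch below. It exercises the standalone update/compute sketches from earlier in this thread; names and values are illustrative, not the merged MetricCountSeqAndTokens API.

import torch

# Hypothetical test: 2 sequences, 3 non-pad tokens (pad id 0).
def test_count_seq_and_tokens() -> None:
    batch = {"encoder_input": torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])}
    counts = count_seq_and_tokens_update(batch, "encoder_input", ignore_index=0)
    assert counts["seq_num"].item() == 2
    assert counts["token_num"].item() == 3

    totals = count_seq_and_tokens_compute(
        seq_num=[counts["seq_num"].numpy()],
        token_num=[counts["token_num"].numpy()],
    )
    assert totals == {"seq_num": 2, "token_num": 3}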

mosheraboh (Collaborator Author) left a comment:

Thanks for the review and useful comments @SagiPolaczek


SagiPolaczek (Collaborator) left a comment:

LGTM

SagiPolaczek (Collaborator):

Merging it to match inner-source code.

SagiPolaczek merged commit 9dc0639 into master on Mar 21, 2024. 5 checks passed.